{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import Counter\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# read_table returns a DataFrame; .values converts it to a NumPy array\n",
    "df = pd.read_table(r'D:\\csdn-python\\第十六周 数据挖掘与机器学习进阶（上）\\05-发给学员\\data\\iris\\x.txt', sep=' ', header=None, engine='python')\n",
    "y = pd.read_table(r'D:\\csdn-python\\第十六周 数据挖掘与机器学习进阶（上）\\05-发给学员\\data\\iris\\y.txt', sep=' ', header=None, engine='python')\n",
    "\n",
    "# Features: columns from index 2 onward; labels: last column of the label file\n",
    "X_train = df.values[:, 2:]\n",
    "y_train = y.values[:, -1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAGmNJREFUeJzt3X+MZWV9x/HPd+6dqR1RIO6kwsLO0BZNd1HUnVAMxtDO2vBLoK1RyFoLaqbMSCvRplU3wUCyiU2qxaizdIRZwJliW1ALgm1ltUHjjzq7RXFZTJCysEJlxHaRjunuzn77x7nLztw5d+5z7zn3nh/3/Upuds4zzznne9fw9ezzfM/zmLsLAFAufVkHAABIH8kdAEqI5A4AJURyB4ASIrkDQAmR3AGghEjuAFBCJHcAKCGSOwCUUDWrG69bt85HRkayuj0AFNLu3bt/5u5DzfplltxHRkY0Pz+f1e0BoJDMbH9IP4ZlAKCESO4AUEIkdwAoIZI7AJQQyR0ASojkDgAlRHIHgBJqmtzN7HQz+7qZ7TOzvWb2/pg+55vZQTN7qPa5vjPhAiiyuTlpZETq64v+nJtLfn7Sa5ZVyEtMRyR90N33mNnLJO02s6+6+yN1/b7h7pekHyKAMpibk8bHpcXF6Hj//uhYkrZube/8q6+WzKRDh9q7Zpk1fXJ392fcfU/t519I2idpfacDA1Au27YdT8zHLC5G7e2ef/jw8cTezjXLrKUxdzMbkfR6Sd+N+fUbzez7ZvYVM9vU4PxxM5s3s/mFhYWWgwVQXE8+2Vp7u/1a7VtWwcndzE6QdLek69z9+bpf75E07O5nS/qUpC/FXcPdp9191N1Hh4aarnsDoEQ2bGitvd1+rfYtq6Dkbmb9ihL7nLt/of737v68u79Q+/l+Sf1mti7VSAEU2vbt0uDgyrbBwai93fP7+6WBgfavWWYh1TIm6VZJ+9z9Ew36vLLWT2Z2Tu26z6UZKIBi27pVmp6WhoejSdDh4eg4dOIz7vydO6WZmfavWWbm7mt3MHuTpG9IeljS0VrzRyRtkCR3v9nMrpU0oaiy5peSPuDu31rruqOjo86SvwDQGjPb7e6jzfqFVMt8093N3V/r7q+rfe5395vd/eZan0+7+yZ3P9vdz22W2AFkqxu14ZOTUrUaPVFXq9ExuiezzToAZCNpvXmIyUlpx47jx0tLx4+nptK5B9bWdFimUxiWAbIxMhIl9HrDw9ITT6Rzj2o1Suj1KhXpyJF07tGrUhuWAVAuSevNQ8Ql9rXakT6SO9Bjktabh6hUWmtH+kjuQI9JWm8e4tgYfmg70kdyB3pM0nrzEFNT0sTE8Sf1SiU6ZjK1e5hQBYACYUIVAHoYyR3oQaGbXqS9OUYr54b2LcJmHZnE6O6ZfDZv3uwAum921n1w0F06/unvdx8YWNk2MBC1N+s3OBhds537Njo3tG8r18xK2jFKmveAHMuYO9BjGr3ElETIC1CtvDwV2rcbL2QllXaMoWPuJHegx/T1Rc+PaTKTjh5du0+j+8adG9q3lWtmJe0YmVAFEKsTG1mEXLOVl6dC+3bjhayksoqR5A70mNBNLwYGovZm/UJfgGrl5anQvt14ISupzGIMGZjvxIcJVSA7s7Puw8PuZtGfs7PJ2pLcN2nfJPF0S5oxiglVACgfxtwBZKpMdepx8h43m3UASF3ohiDd2DikE4oQN8MyAFJXpjr1OFnGzbAMgMyEbgjSjY1DOqEIcZPcAaSuTHXqcYoQN8kdQOrKVKcepwhxk9wBpC50Q5BubBzSCUWImwlVACgQJlQBdES31oLPUx15nmIJFvIaayc+LD8AFE/c2uSh677H9Uu6nntW3znLNePF8gMA0tatteDzVP+ep1gkhmUAdEAn6rjjrpmnOvI8xdIKkjuAYN1aCz5PdeR5iqUVJHcAweLqu0PXfY/rl3Q9927IUyytILkD
CBZX3z0zI+3cubJt586ovVm/RrXheaojz1MsrWg6oWpmp0u6Q9IrJR2VNO3un6zrY5I+KekiSYuSrnL3PWtdlwlVAGhdmhOqRyR90N1/S9K5kt5nZhvr+lwo6czaZ1zSjhbjBZBQo1rsTtSgl0WS75z7v6+QesnlH0n/JOktdW1/K+nKZcc/knTKWtehzh1IT6Na7ImJ9uvSs6zl7oYk9etZ1r6rE3XuZjYi6UFJZ7n788vavyzpY+7+zdrxLkl/6e4Nx10YlgHS06gWu1KRlpbav27e11VPIkn9eqnWczezEyTdLem65Yn92K9jTln1/xpmNm5m82Y2v7CwEHprAE00qrlOktjXum4ZJKlfL0Lte1ByN7N+RYl9zt2/ENPlgKTTlx2fJunp+k7uPu3uo+4+OjQ01E68AGI0qrmuVDpz3TJIUr9ehNr3psm9Vglzq6R97v6JBt3ukfQui5wr6aC7P5NinADW0KgWe3y8/br0ItRyJ5Gkfr0Qte/NBuUlvUnREMsPJD1U+1wk6RpJ19T6mKTPSPqxpIcljTa7LhOqQLpmZ92Hh93Noj+PTe7FtYe2lV2S75zV35dYOAwAyoeFwwBIiq/HnpyUqtXojctqNToOPTdvihBjFqpZBwCgc+bmonH3xcXoeP9+6aqrpCNHjvdZWpJ21F47nJpa+9zx8ejnvLx6X4QYs8KwDFBiray/XqmsTPp5W8c8ThFiTBvDMgBaqruur4kvQi13EWLMCskdKLFW6q7ra+KLUMtdhBizQnIHSiyuHrvaYKbt2Fj1WufmrZa7CDFmheQOlFjcWuS33SZNTBx/Uq9UouPlk6mNzs3bOuZFiDErTKgCQIEwoQqU2JZ3PCqrHJGZyypHtOUdjwbXrkvp14bH3Tv0Hq3EUur119MW8hprJz4sPwC0Z+zt+1w6umIt8ei4vi1az71e2muRT0z4qvtK7n19ze/RSixFXX89bWL5AaCcrHJEOhr2/mF97bqUfm14tRq+tHD9PVqJpajrr6ctdFiG5A4UjJkrfguFePX/iff1rW6LrisdPdpOPK31XX6PVmJJEnfa3zlLjLkDZdUXvgNH3HruadeGt7JmfP09Woml7Ouvp43kDhTM2Nse0+qNzjymbXXtupR+bXjcPaToabnZPVqJpfTrr6ctZGC+Ex8mVIH2jb19n6vvcDSJ2nfYx96+zycm3CuVaLKwUomfTD0m7bXI4+4deo9WYini+utpExOqAFA+jLkDOdetuuu5h+c0ctOI+m7o08hNI5p7uOwF3pBYzx3IRLfWIZ97eE7j945r8XB0o/0H92v83uhGW1/DO/plxrAMkIFu1V2P3DSi/QdX32j4xGE9cV2KN0LXMCwD5Fi31iF/8mD8BRu1ozxI7kAGulV3veHE+As2akd5kNyBDHSr7nr72HYN9q+80WD/oLaPlbnAGxLJHchEt9Yh3/qarZp+67SGTxyWyTR84rCm3zrNZGoPYEIVAAqECVUg50LrzztRp56n2veeW2e9S6hzBzIQWn/eiTr1PNW+d6vevxcxLANkILT+vBN16nmqfS/TOuvdwrAMkGOh9eedqFPPU+17t+r9exHJHchAaP15J+rU81T73ovrrHcLyR3IQGj9eSfq1PNU+96T66x3CckdyEBo/Xkn6tTzVPverXr/XsSEKgAUSGoTqmY2Y2bPmtkPG/z+fDM7aGYP1T7XtxMwUERJ6sXXf3y97AZ78bP+4+tjrxd6j8n7JlW9sSq7wVS9sarJ+yajGGPqyKktL7+mT+5m9mZJL0i6w93Pivn9+ZL+3N0vaeXGPLmj6OrrxaVo7DpkiGP9x9fr6ReebnqP/r5+mZkOLR1a8x6T901qx/yOVeeP/eJWfXvHu1+sI5ek/v5oCOTQ8UtqcJDhkKJI7cnd3R+U9PNUogJKZNuubSsSuyQtHl7Utl3bmp4bktgl6fDRwysSe6N7TO+ejj1/1y1jKxK7JB0+vDKxS9FLRNuah40CSWtC
9Y1m9n0z+4qZbWrUyczGzWzezOYXFhZSujWQjSzrxevvseRL8R0Pnh5+TWrLSyWN5L5H0rC7ny3pU5K+1Kiju0+7+6i7jw4NDaVwayA7WdaL19+jYpX4jic+FX5NastLJXFyd/fn3f2F2s/3S+o3s3WJIwNyLkm9+KknnBp0j/6+fg1UBpreY3zzeOz5Y+/dtaqOvL9fGlh5SWrLSyhxcjezV5qZ1X4+p3bN55JeF8i7JPXiP/ngT1Yl+FNPOFWzfzC74no7L9+pmctmmt5j6uIpTYxOvPgEX7GKJkYn9MBfv3tVHfnOndLMDLXlZRdSLXOnpPMlrZP0U0kfldQvSe5+s5ldK2lC0hFJv5T0AXf/VrMbUy0DAK1Ls1rmSnc/xd373f00d7/V3W9295trv/+0u29y97Pd/dyQxA7kRbfWNW9Ugx4ST9y5eVqPvRFq6bPFG6roWUnq1FvRqAZ9YnRCUxdPrRlPta+qI0ePrDq3vr0TcSdRv067RC19WkKf3Enu6FndWte8emM1tlSxYhUduf54gm4UT6gs1mNvhHXaO4f13IEmulWn3qgGvb496X2zWI+9EdZpzx7JHT2rW3XqjWrQ69uT3jeL9dgbYZ327JHc0bO6ta55oxr0+va4eKp98dsc17dntR57I6zTnj2SO3pWt9Y1b1SDvnwytVE8t11+W+y5t11+Wy7WY2+Eddqzx4QqABQIE6oA0MPiB/SAApl7eE7bdm3Tkwef1IYTN2j72PZEQxRb7tiiXf+568XjsTPG9KpXvErTu6e15EuqWEXjm8c1dfGUJu+bXNUuaVXbeRvOWxWjpKC2PA23oDgYlkGhpf0iUn1iX8vGdRv1yM8eCepbscqK0seByoDcXYePHn6xLXRjDvQ2XmJCT0j7RSS7wVKIKl15ejkJ2WPMHT0hyw0zuqVM3wXdQ3JHoWW5YUa3lOm7oHtI7ii0tF9EGjtjLLjvxnUbg/vWv406UBlQf1//irbQjTmAECR3FFraLyI98K4HViX4sTPGYl8k2vu+vbHtcW23//7tK2KcuWxGOy/f2dbGHEAIJlQBoECYUAUCxG160cpGGKF9k2yuUYSNOZA/PLmjZ8XVyLdSax5aY5+kFr9bG4qgOKhzB5poZXOMuFrz0Br7JLX43dpQBMXBsAzQRCv143F9Q2vsk9Ti90IdPzqD5I6e1Ur9eFzf0Br7JLX4vVDHj84guaNnxdXIt1JrHlpjn6QWv1sbiqB8SO7oWXE18q3UmofW2Cepxe/WhiIoHyZUAaBAmFBFVxWhFjtpTTtQJDy5I7Ei1GLHxRi3pnre4gbq8eSOrtm2a9uKpClJi4cXtW3XtowiWi0uxkNLh1Ykdil/cQPtIrkjsSLUYietaQeKhuSOxIpQi520ph0oGpI7EitCLXZcjHFrquctbqBdJHckVoRa7LgY49ZUz1vcQLuaVsuY2YykSyQ96+5nxfzeJH1S0kWSFiVd5e57mt2YahkAaF2a1TK3Sbpgjd9fKOnM2mdc0o6QAIHlJu+bVPXGquwGU/XGqibvm0zUL+3106mHR9FUm3Vw9wfNbGSNLpdJusOjfwJ8x8xOMrNT3P2ZlGJEyU3eN6kd88efCZZ86cXjqYunWu5XX9O+/+B+jd87Lkktr5++/+B+Xf2lq1es8d7K9YCspDHmvl7SU8uOD9TagCDTu6eD2kP7Jam7jzv38NHDKzbvaOV6QFbSSO4W0xY7kG9m42Y2b2bzCwsLKdwaZbDkS0Htof06sX560r5At6WR3A9IOn3Z8WmSno7r6O7T7j7q7qNDQ0Mp3BplULFKUHtov06sn560L9BtaST3eyS9yyLnSjrIeDtaMb55PKg9tF/a66e3ssY7kBdNk7uZ3Snp25JebWYHzOw9ZnaNmV1T63K/pMclPSbps5LiyxeABqYuntLE6MSLT+AVq2hidGLFJGkr/dJeP72VNd6BvGBVSAAoEFaFBIAeRnIHgBIiuQNACZHcAaCESO4A
UEIkdwAoIZI7AJQQyR0ASojkDgAlRHIHgBIiuQNACZHcAaCESO4AUEIkdwAoIZI7AJQQyR0ASojkDgAlRHIHgBIiuQNACZHcAaCESO4AUEIkdwAoIZI7AJQQyR0ASojk3oq5OWlkROrri/6cm8s6IgCIVc06gMKYm5PGx6XFxeh4//7oWJK2bs0uLgCIwZN7qG3bjif2YxYXo3YAyBmSe6gnn2ytHQAyRHIPtWFDa+0AkCGSe6jt26XBwZVtg4NROwDkDMk91Nat0vS0NDwsmUV/Tk8zmQogl6iWacXWrSRzAIUQ9ORuZheY2Y/M7DEz+1DM768yswUze6j2eW/6oeYUte8Acqjpk7uZVSR9RtJbJB2Q9D0zu8fdH6nr+vfufm0HYswvat8B5FTIk/s5kh5z98fd/ZCkz0u6rLNhFQS17wByKiS5r5f01LLjA7W2en9oZj8ws7vM7PS4C5nZuJnNm9n8wsJCG+HmDLXvAHIqJLlbTJvXHd8racTdXyvpAUm3x13I3afdfdTdR4eGhlqLNI+ofQeQUyHJ/YCk5U/ip0l6enkHd3/O3f+vdvhZSZvTCS/nqH0HkFMhyf17ks40szPMbEDSFZLuWd7BzE5ZdnippH3phZhj1L4DyKmm1TLufsTMrpX0L5Iqkmbcfa+Z3Shp3t3vkfRnZnappCOSfi7pqg7GnC/UvgPIoaA6d3e/391f5e6/4e7ba23X1xK73P3D7r7J3c92999x90c7GXTbQmvSt2yJnsSPfbZsaXx+6DWphwfQTe6eyWfz5s3eVbOz7oOD7tLxz+Bg1L7c2NjKPsc+GzeuPn9gwL2/v/k1Q+8NAE0oGjFpmmMt6tt9o6OjPj8/370bjoxELxnVGx6Wnnji+LHFFQe1qP6aofcGgCbMbLe7jzbr1zsLh3WzJr3+mtTDA+iy3knu3axJr78m9fAAuqx3kntoTfrYWPz5GzeuPn9gQOrvb35N6uEBdFnvJPfQmvQHHlid4MfGpL17V58/MyPt3Nn8mtTDA+iy3plQBYASYEIVAHpYbyX3yUmpWo2GRqrV6DjuhaVWXjji5SQAOdQ7wzKTk9KOHWF9zaJXjY4ZHIwfI6/frGOtvgCQgtBhmd5J7tWqtLTU/vlxLxzxchKALmPMvV6SxC7Fv3DEy0kAcqp3knulkuz8uBeOeDkJQE71TnI/tnF1iPr1ZRq9cMTLSQByqneS+9SUNDFx/Am+UomO415Y+tznwl444uUkADnVOxOqAFAC5Z1QDa0rj6tp37RpZU37pk3R+jDL2wYGpJNPXtl28snRNdevX9m+fj2bdQDIp5BF3zvxaWuzjtBNLyYm4jfc6MaHzToAdJBKuVlHaF150pr2pNisA0CHlHNYJrSuPMvELrFZB4DMFSu5h9aVJ61pT4rNOgBkrFjJPbSuvJWa9rSxWQeAHChWcg+tK29U075x48p+Gzeu3kmpv1866aSVbSedFE2DnnrqyvZTT5VmZ9msA0DuFGtCFQB6XDknVBtJUkMed25cPTwAFEg16wASq19Tff/+42PuzYY94s595ztX93vkkSjB792bXtwA0EHFH5ZJUkPe6NxGMvq7AoBjemdYJkkNOXXmAEqq+Mk9SQ05deYASqr4yT1JDXncuY3Ul1ECQI4VP7knqSGPO3d2Nr4enslUAAUSNKFqZhdI+qSkiqRb3P1jdb//FUl3SNos6TlJ73D3J9a6JnXuANC61CZUzawi6TOSLpS0UdKVZlY/RvEeSf/t7r8p6W8k/VXrIQMA0hIyLHOOpMfc/XF3PyTp85Iuq+tzmaTbaz/fJWnMrH4jUgBAt4Qk9/WSnlp2fKDWFtvH3Y9IOijpFWkECABoXUhyj3sCrx+oD+kjMxs3s3kzm19YWAiJDwDQhpDkfkDS6cuOT5P0dKM+ZlaVdKKkn9dfyN2n3X3U3UeHhobaixgA0FRIcv+epDPN7AwzG5B0haR76vrcI+mPaz+/TdLXPKt1DQAAwaWQ
F0m6SVEp5Iy7bzezGxVt1HqPmb1E0uckvV7RE/sV7v54k2suSGphYZdV1kn6WYLz84Tvkk9l+i5Sub5PL3+XYXdvOvSR2cJhSZnZfEitZxHwXfKpTN9FKtf34bs0V/w3VAEAq5DcAaCEipzcp7MOIEV8l3wq03eRyvV9+C5NFHbMHQDQWJGf3AEADRQuuZvZjJk9a2Y/zDqWpMzsdDP7upntM7O9Zvb+rGNql5m9xMz+3cy+X/suN2QdU1JmVjGz/zCzL2cdSxJm9oSZPWxmD5lZoZdiNbOTzOwuM3u09t/NG7OOqR1m9ura/x7HPs+b2XWp3qNowzJm9mZJL0i6w93PyjqeJMzsFEmnuPseM3uZpN2SLnf3RzIOrWW1heJe6u4vmFm/pG9Ker+7fyfj0NpmZh+QNCrp5e5+SdbxtMvMnpA06u6Frws3s9slfcPdb6m9VDno7v+TdVxJ1Fbe/Ymk33b3JO/+rFC4J3d3f1AxSxsUkbs/4+57aj//QtI+rV6UrRA88kLtsL/2KdaTwzJmdpqkiyXdknUsiJjZyyW9WdKtkuTuh4qe2GvGJP04zcQuFTC5l5WZjSh6w/e72UbSvtowxkOSnpX0VXcv7HdR9Eb2X0g6mnUgKXBJ/2pmu81sPOtgEvh1SQuSdtaGy24xs5dmHVQKrpB0Z9oXJbnngJmdIOluSde5+/NZx9Mud19y99cpWlzuHDMr5LCZmV0i6Vl33511LCk5z93foGjDnffVhjaLqCrpDZJ2uPvrJf2vpA9lG1IytaGlSyX9Y9rXJrlnrDY+fbekOXf/QtbxpKH2T+V/k3RBxqG06zxJl9bGqj8v6XfNbDbbkNrn7k/X/nxW0hcVbcBTRAckHVj2L8K7FCX7IrtQ0h53/2naFya5Z6g2CXmrpH3u/oms40nCzIbM7KTaz78qaYukR7ONqj3u/mF3P83dRxT9k/lr7v7OjMNqi5m9tDZZr9oQxu9JKmSlmbv/l6SnzOzVtaYxSYUrPqhzpTowJCNF/8wpFDO7U9L5ktaZ2QFJH3X3W7ONqm3nSfojSQ/Xxqol6SPufn+GMbXrFEm312b++yT9g7sXuoSwJH5N0hdru15WJf2du/9ztiEl8qeS5mrDGY9LujrjeNpmZoOS3iLpTzpy/aKVQgIAmmNYBgBKiOQOACVEcgeAEiK5A0AJkdwBoIRI7gBQQiR3ACghkjsAlND/A16qVcpcQdy0AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Scatter the first two columns of X_train, one colour per class label\n",
    "for class_id, colour in ((0, 'r'), (1, 'g'), (2, 'b')):\n",
    "    members = X_train[y_train == class_id]\n",
    "    plt.scatter(members[:, 0], members[:, 1], color=colour)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gini(y):\n",
    "    '''基尼系数计算公式'''\n",
    "    counter = Counter(y)\n",
    "    result = 0\n",
    "    for v in counter.values():\n",
    "        result += (v / len(y)) ** 2\n",
    "    return 1 - result\n",
    "\n",
    "def cut(X, y, d, v):\n",
    "    '''把数据集左右分类'''\n",
    "    ind_left = (X[:, d] <= v)\n",
    "    ind_right = (X[:, d] > v)\n",
    "    return X[ind_left], X[ind_right], y[ind_left], y[ind_right]\n",
    "\n",
    "def try_split(X, y):\n",
    "    '''Search every feature for the threshold with the lowest split score.\n",
    "\n",
    "    Candidate thresholds are the midpoints between consecutive distinct\n",
    "    values of each feature. A candidate is scored with\n",
    "    gini(y_left) + gini(y_right); this sum is not weighted by side size and\n",
    "    can reach 1 or more when both sides stay mixed.\n",
    "\n",
    "    Returns (best_d, best_v, best_g): feature index, threshold and score of\n",
    "    the best split, or (-1, -1, 1) when no feature has two distinct values.\n",
    "    '''\n",
    "    # Start from infinity: the summed score is not bounded by 1, so a starting\n",
    "    # value of 1 would reject valid splits (e.g. two children with Gini 0.5 each).\n",
    "    best_g = float('inf')\n",
    "    best_d = -1\n",
    "    best_v = -1\n",
    "\n",
    "    for d in range(X.shape[1]):\n",
    "        sorted_index = np.argsort(X[:, d])\n",
    "        for i in range(len(X) - 1):\n",
    "            lower = X[sorted_index[i], d]\n",
    "            upper = X[sorted_index[i + 1], d]\n",
    "            # Equal neighbours leave no threshold that separates them\n",
    "            if lower == upper:\n",
    "                continue\n",
    "            v = (lower + upper) / 2\n",
    "            X_left, X_right, y_left, y_right = cut(X, y, d, v)\n",
    "            g_all = gini(y_left) + gini(y_right)\n",
    "\n",
    "            if g_all < best_g:\n",
    "                best_g = g_all\n",
    "                best_d = d\n",
    "                best_v = v\n",
    "\n",
    "    if best_d == -1:\n",
    "        # No candidate threshold exists; keep the historical return value\n",
    "        return -1, -1, 1\n",
    "    return best_d, best_v, best_g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0, 2.45, 0.5)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "try_split(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Node():\n",
    "    '''决策树节点'''\n",
    "    def __init__(self, d=None, v=None, g=None, l=None):\n",
    "        self.dim = d\n",
    "        self.value = v\n",
    "        self.gini = g\n",
    "        self.label = l\n",
    "        \n",
    "        self.children_left = None\n",
    "        self.children_right = None\n",
    "        \n",
    "    def __repr__(self):\n",
    "        return 'Node(d={},v={},g={},l={})'.format(self.dim, self.value, self.gini, self.label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_tree(X, y):\n",
    "    '''Recursively build a decision tree from samples X and labels y.\n",
    "\n",
    "    Returns the root Node of the subtree, or None when the samples need no\n",
    "    split (at most one class present, or no feature offers a threshold).\n",
    "    A branch that comes back as None is replaced by a leaf Node labelled\n",
    "    with the most common class on that side.\n",
    "    '''\n",
    "    # Test purity on the labels themselves. The split score g is also 0 when\n",
    "    # one split separates a mixed node perfectly, and that split must be kept\n",
    "    # rather than collapsed into a majority-label leaf.\n",
    "    if len(Counter(y)) <= 1:\n",
    "        return None\n",
    "\n",
    "    d, v, g = try_split(X, y)\n",
    "    # d == -1 means try_split found no threshold on any feature\n",
    "    if d == -1:\n",
    "        return None\n",
    "\n",
    "    node = Node(d, v, g)\n",
    "    X_left, X_right, y_left, y_right = cut(X, y, d, v)\n",
    "\n",
    "    branches = []\n",
    "    for X_side, y_side in ((X_left, y_left), (X_right, y_right)):\n",
    "        child = create_tree(X_side, y_side)\n",
    "        if child is None:\n",
    "            # Nothing left to split on this side: make it a leaf\n",
    "            child = Node(l=Counter(y_side).most_common(1)[0][0])\n",
    "        branches.append(child)\n",
    "    node.children_left, node.children_right = branches\n",
    "\n",
    "    return node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Node(d=0,v=2.45,g=0.5,l=None)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree = create_tree(X_train, y_train)\n",
    "# Root node\n",
    "tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Node(d=None,v=None,g=None,l=0.0)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree.children_left"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Node(d=1,v=1.75,g=0.21057149006459386,l=None)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree.children_right"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Node(d=0,v=5.35,g=0.10872781065088755,l=None)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree.children_right.children_left"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_tree(node):\n",
    "    '''按照graohviz需要的格式，封装整颗决策树'''\n",
    "    if node is None:\n",
    "        return ''\n",
    "    # 封装根节点\n",
    "    result = '{} [label=\"{}\"]\\n'.format(id(node), node)\n",
    "    # 递归封装左子节点\n",
    "    if node.children_left is not None:\n",
    "        result += '{} [label=\"{}\"]\\n'.format(id(node.children_left), node.children_left)\n",
    "        result += '\"{}\" -> \"{}\"\\n'.format(id(node), id(node.children_left))\n",
    "        result += show_tree(node.children_left)\n",
    "    # 递归封装右子节点    \n",
    "    if node.children_right is not None:\n",
    "        result += '{} [label=\"{}\"]\\n'.format(id(node.children_right), node.children_right)\n",
    "        result += '\"{}\" -> \"{}\"\\n'.format(id(node), id(node.children_right))\n",
    "        result += show_tree(node.children_right)\n",
    "        \n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the tree structure as a Graphviz DOT file\n",
    "with open('temp.dot', 'w') as f:\n",
    "    f.write('digraph{\\n' + show_tree(tree) + '}\\n')\n",
    "# Command-line equivalent for rendering: E:\\graphviz\\release\\bin\\dot.exe temp.dot -Tpng -o temp.png"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "# Render temp.dot to temp1.png with the Graphviz dot tool; the cell displays os.system's return value\n",
    "os.system(r'E:\\graphviz\\release\\bin\\dot.exe temp.dot -Tpng -o temp1.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(x, node):\n",
    "    '''决策树预测分类'''\n",
    "    if node.label is not None:\n",
    "        return node.label\n",
    "    \n",
    "    if x[node.dim] <= node.value:\n",
    "        # left\n",
    "        return predict(x, node.children_left)\n",
    "    else:\n",
    "        # right\n",
    "        return predict(x, node.children_right)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.0"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predict(X_train[100], tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Node(d=0,v=2.45,g=0.5,l=None)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 利用封装好的决策树算法进行预测分类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<ML.tree.DecisionTreeClassifier at 0x241912786a0>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from ML.tree import DecisionTreeClassifier\n",
    "\n",
    "# Load the feature and label files as NumPy arrays (no column slicing here)\n",
    "X = np.loadtxt(r'D:\\csdn-python\\第十六周 数据挖掘与机器学习进阶（上）\\05-发给学员\\data\\iris\\x.txt')\n",
    "y = np.loadtxt(r'D:\\csdn-python\\第十六周 数据挖掘与机器学习进阶（上）\\05-发给学员\\data\\iris\\y.txt')\n",
    "\n",
    "dt_clf = DecisionTreeClassifier()\n",
    "dt_clf.fit(X, y) # train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Node(d=2,v=2.45,g=0.5,l=None)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dt_clf.tree_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0., 0., 0.])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dt_clf.predict(X[:3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.98"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from ML.metrics import accuracy_score\n",
    "\n",
    "# Score the classifier on the same X, y it was fitted on\n",
    "y_predict = dt_clf.predict(X)\n",
    "accuracy_score(y, y_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.98"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dt_clf.score(X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 利用封装好的kNN算法进行预测分类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<ML.knn.KNeighborsClassifier at 0x2419159a320>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from ML.knn import KNeighborsClassifier\n",
    "\n",
    "knn_clf = KNeighborsClassifier()\n",
    "knn_clf.fit(X, y) # pseudo-training (presumably fit only stores the data; confirm in ML.knn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9666666666666667"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn_clf.score(X, y)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
